#!/usr/bin/env python3
# coding: utf-8

import keras
import matplotlib.pyplot as plt
from keras.callbacks import LambdaCallback
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM, Flatten
from keras.optimizers import RMSprop, Adam, Nadam,Adamax
from keras.utils.data_utils import get_file
import numpy as np
import random
import sys
import io
from keras.layers import Input, Dense, LSTM, RepeatVector, Reshape, Permute,Bidirectional
from keras.models import Model
from keras.layers import Bidirectional, concatenate, Conv1D, MaxPooling1D, GlobalMaxPooling1D, Dropout, BatchNormalization
import gc
from keras.utils import plot_model, np_utils
import time
import datetime
from keras.utils import to_categorical
from keras.callbacks import EarlyStopping
import warnings
from sklearn.model_selection import train_test_split
from numpy import random, mat
import multiprocessing
import os
import time
import math
from keras.utils import plot_model
warnings.filterwarnings('ignore')
from matplotlib import pyplot as plt
import itertools
from sklearn.metrics import confusion_matrix
from keras.models import load_model
from python_speech_features import mfcc
import pandas as pd
import numpy as np
from utils import read_wav


def plot_sonfusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues,
                          save_dir=None):
    '''
    Draw a confusion matrix with a colour bar and per-cell values, then save it as a JPEG.

    Args:
        cm: 2-D confusion matrix; row i is drawn at y=i (axis labelled 'True label'),
            column j at x=j (axis labelled 'Predict label').
        classes: tick labels, one per row/column of ``cm``.
        normalize: if True, divide each row by its sum before drawing so every
            row shows fractions; all-zero rows are drawn as zeros instead of NaN.
        title: figure title.
        cmap: matplotlib colormap used for the cell shading.
        save_dir: directory the image is written to. When omitted, a module-level
            ``path`` variable is used if one has been set (the original behaviour),
            otherwise the current working directory.

    Side effects:
        Draws onto the current matplotlib figure and writes
        ``confusion_matrix_test_0_1_5_kyle_test.jpg`` into ``save_dir``.
    '''
    cm = np.asarray(cm)
    if normalize:
        # Normalise BEFORE imshow/colorbar so the colours, the colour bar and the
        # printed numbers all describe the same matrix (previously only the text
        # was normalised). Rows summing to zero stay zero instead of becoming NaN.
        row_sums = cm.sum(axis=1, keepdims=True).astype(float)
        cm = np.divide(cm.astype(float), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Values above half the maximum get white text, the rest black, so the
    # number stays readable against the cell shading.
    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        cell_text = format(cm[i, j], '.2f') if normalize else cm[i, j]
        plt.text(j, i, cell_text, horizontalalignment='center',
                 color='white' if cm[i, j] > thresh else 'black')
    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predict label')

    if save_dir is None:
        # `path` is not defined anywhere in this file, so the original code raised
        # NameError unless a caller injected it; honour it if present, else use cwd.
        save_dir = globals().get('path', os.curdir)
    subfile = os.path.join(save_dir, 'confusion_matrix_test_0_1_5_kyle_test.jpg')
    plt.savefig(subfile, bbox_inches='tight')


def get_feature(fs, signal):
    '''
    Extract MFCC features from an audio signal.

    Args:
        fs: sampling rate of ``signal``; converted with ``int()`` and passed to
            ``mfcc`` as its sample-rate argument.
        signal: audio samples, handed to ``mfcc`` unchanged.

    Returns:
        Whatever ``mfcc`` returns. If that result has length 0, an error line
        (including the signal length) is printed, but the empty result is still
        returned to the caller.
    '''
    features = mfcc(signal, int(fs))
    if not len(features):
        print("ERROR.. failed to extract mfcc feature:", len(signal))
    return features


def deal_data(train_date_c1s):
    '''
    Standardise data to zero mean and unit variance (z-score).

    A single mean and a single (population, ddof=0) standard deviation are
    computed over ALL elements of the input, not per column.

    Args:
        train_date_c1s: numeric array or array-like of any shape.

    Returns:
        A float array of the same shape. A constant input (zero standard
        deviation) is returned centred, i.e. all zeros, instead of the NaNs a
        division by zero would produce. An empty input is returned as an empty
        float array.
    '''
    data = np.asarray(train_date_c1s).astype(float)
    if data.size == 0:
        # mean/std of an empty array are NaN; nothing to scale anyway.
        return data

    mean = np.mean(data)
    std = np.std(data)  # identical to sqrt(np.var(data))
    if std == 0:
        # e.g. a silent/all-zero clip: avoid dividing by zero, which would turn
        # every feature into NaN (the warning is silenced module-wide).
        return data - mean
    return (data - mean) / std



if __name__ == '__main__':
    # Inference entry point: load the trained model, turn one WAV file into a
    # fixed-size MFCC matrix, and print the predicted speaker.
    model = load_model('/home/banana/bear_voice/src/mfcc_tcnn_speaker/script/model/model_0499_acc=1.0.h5')

    # Optional first command-line argument overrides the default input file.
    audio_path = sys.argv[1] if len(sys.argv) > 1 else 'out_audio.wav'
    fs, sig = read_wav(audio_path)
    # Use the file's own sample rate instead of a hard-coded 16000: mfcc uses
    # the rate to convert its window/step durations into sample counts, so a
    # wrong rate mis-frames any recording that is not 16 kHz.
    df_feature = np.nan_to_num(get_feature(fs, sig))

    # The model input is reshaped to (1, mfcc_max, n_mfcc) below, so the feature
    # matrix must have exactly mfcc_max rows: zero-pad short clips at the end,
    # truncate long ones (slicing instead of deleting one row per iteration).
    mfcc_max = 495
    n_mfcc = 13
    n_frames = df_feature.shape[0]
    if n_frames < mfcc_max:
        df_feature = np.pad(df_feature, ((0, mfcc_max - n_frames), (0, 0)), mode='constant')
    else:
        df_feature = df_feature[:mfcc_max]

    df_feature_scale = deal_data(df_feature)
    df_feature_scale = np.reshape(df_feature_scale, (1, mfcc_max, n_mfcc))
    pred_y = model.predict(df_feature_scale)
    # argmax over axis 1 yields one label per batch item; the batch holds one clip.
    pred_label = int(np.argmax(pred_y, axis=1)[0])
    speaker_label = 'unknown' if pred_label == 0 else 'linkunling'
    output_str = "speaker:  %s" % (speaker_label)
    print(output_str)



